import math
import numpy as np
import cv2
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
import os
import sys

# Debug output at import time: show the module search path and the working
# directory, presumably to diagnose resolution of the project-local imports
# below (datasets.*, models.*, dice_loss).
print(sys.path)
print(os.getcwd())
import torch
import torch.nn as nn
from datasets.table.dataloader import LoadTableImageAndLabels
from models.assembly.segmentation_table import Segmentation_Model
from models.assembly.deeplab import DeepLabV3
from models.assembly.my_pan_model import PanModel
from dice_loss import dice_coeff


def save_tensor(tensor, i, e, out_dir='results'):
    """Save a grayscale preview of the first sample's first channel as JPEG.

    The file is written to ``<out_dir>/<e>-<i>.jpg``.

    Args:
        tensor: batched tensor; ``tensor[0]`` is transposed with ``(1, 2, 0)``,
            so it is presumably ``(N, C, H, W)``. Values are scaled by 255 and
            are expected in ``[0, 1]`` (the training loops pass sigmoid outputs).
        i: iteration index, used in the file name.
        e: epoch index, used in the file name.
        out_dir: output directory; created if it does not exist.
    """
    # (C, H, W) -> (H, W, C); only channel 0 is kept below.
    np_array = tensor[0].detach().cpu().numpy().transpose(1, 2, 0)
    # Clip so out-of-range values saturate instead of wrapping in the uint8 cast.
    np_array = (np.clip(np_array, 0.0, 1.0) * 255).astype(np.uint8)[:, :, 0]
    # cv2.imwrite does not raise on a missing directory; it just returns False.
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, str(e) + '-' + str(i) + '.jpg')
    if not cv2.imwrite(path, np_array):
        # Best-effort preview: report the failure but keep training.
        print('save_tensor: failed to write', path)


def _iou(pred, target, size_average=True):
    """Return the soft-IoU loss ``1 - IoU`` averaged over the batch.

    For each sample (dim 0), the soft intersection is ``sum(pred * target)``
    and the soft union is ``sum(pred) + sum(target) - intersection``.

    Args:
        pred: predictions with the batch on dim 0 (presumably probabilities
            in ``[0, 1]``; the training loops pass sigmoid outputs).
        target: ground truth with the same shape as ``pred``.
        size_average: accepted for API compatibility; the result is always
            the batch mean.

    Returns:
        0-d tensor with the mean of ``1 - IoU`` over the batch. A sample whose
        union is zero (empty prediction and empty target) is treated as a
        perfect match and contributes 0 instead of NaN.
    """
    b = pred.shape[0]
    inter = (target * pred).reshape(b, -1).sum(dim=1)
    union = target.reshape(b, -1).sum(dim=1) + pred.reshape(b, -1).sum(dim=1) - inter
    empty = union == 0
    # Divide by 1 where the union is empty so neither the value nor the
    # gradient becomes NaN; torch.where then substitutes IoU = 1 there.
    iou = torch.where(empty, torch.ones_like(inter), inter / union.masked_fill(empty, 1))
    return (1 - iou).mean()


class IOU(torch.nn.Module):
    """``nn.Module`` wrapper around ``_iou`` (soft-IoU loss, ``1 - IoU``).

    Args:
        size_average: stored on the module and forwarded to ``_iou``.
    """

    def __init__(self, size_average=True):
        super().__init__()
        self.size_average = size_average

    def forward(self, pred, target):
        """Return ``_iou(pred, target, self.size_average)``."""
        averaged = self.size_average
        return _iou(pred, target, averaged)


class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss computed on probabilities (not logits).

    Per element, with ``p`` the clamped prediction and ``t`` the target::

        -alpha * (1 - p)**gamma * t * log(p)
        - (1 - alpha) * p**gamma * (1 - t) * log(1 - p)

    Args:
        gamma: focusing exponent.
        alpha: weight of the positive term; the negative term gets
            ``1 - alpha``.
        reduction: ``'elementwise_mean'`` returns the mean, ``'sum'`` the sum;
            any other value returns the per-element loss.
        eps: predictions are clamped to ``[eps, 1 - eps]`` so both logs stay
            finite.
    """

    def __init__(self, gamma=2, alpha=0.6, reduction='elementwise_mean', eps=1e-4):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.eps = eps

    def forward(self, _input, target):
        # Clamp out of place: writing into _input would mutate the caller's
        # tensor and break autograd when it is a sigmoid output. The upper
        # bound keeps log(1 - p) finite when p == 1.
        p = _input.clamp(self.eps, 1 - self.eps)
        alpha = self.alpha
        loss = (- alpha * (1 - p) ** self.gamma * target * torch.log(p)
                - (1 - alpha) * p ** self.gamma * (1 - target) * torch.log(1 - p))
        if self.reduction == 'elementwise_mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


def balance_mask(score, label, mask):
    """Select all positive pixels plus the highest-scoring negatives.

    Positives are pixels with ``label > 0.5``. The number of negatives to keep
    is the number of pixels with ``mask > 0.5``; per the author's note, ``mask``
    is the positive region dilated by a few pixels. Among negatives
    (``label <= 0.5``), every pixel whose score is at least the k-th highest
    negative score is selected (hard negatives).

    Args:
        score: predicted probabilities, same shape as ``label``.
        label: ground truth; positives where ``> 0.5``.
        mask: region whose ``> 0.5`` count sets how many negatives to keep.

    Returns:
        ``(selected_mask, total_num)``: a mask shaped like ``label`` (1.0 where
        selected) and the pixel count used for normalisation. With no
        positives every pixel is selected and ``total_num`` is
        ``label.numel()``; when no negatives are kept, ``total_num`` is the
        positive count.
    """
    positives = label > 0.5
    pos_num = int(positives.sum())
    if pos_num == 0:
        return torch.ones_like(label), label.numel()
    selected_mask = torch.zeros_like(label)
    selected_mask[positives] = 1.0
    neg_score = score[label <= 0.5]
    # Cannot keep more negatives than actually exist.
    neg_num = min(int((mask > 0.5).sum()), neg_score.numel())
    if neg_num == 0:
        return selected_mask, pos_num
    # k-th largest negative score; selecting pixels needs no gradient.
    threshold = neg_score.detach().topk(neg_num).values[-1]
    selected_mask[score >= threshold] = 1.0
    return selected_mask, pos_num + neg_num


#     return selected_mask, total_num
def pos_hard_mining(outputs, targets, mask):
    """Balanced BCE over the pixels chosen by ``balance_mask``.

    Returns ``sum(3 * bce * selected) / total_num`` as a 0-d tensor, where
    ``(selected, total_num) = balance_mask(outputs, targets, mask)`` and
    ``bce`` is the per-pixel binary cross-entropy of ``outputs`` (probabilities)
    against ``targets``.
    """
    selected, total_num = balance_mask(outputs, targets, mask)
    per_pixel = F.binary_cross_entropy(outputs, targets, reduction='none')
    weighted = (per_pixel * 3).mul(selected)
    return weighted.sum() / total_num


def hard_mining(outputs, targets, keep_ratio=0.5):
    """Online hard-example mining for binary cross-entropy.

    Computes per-element BCE and averages the largest ``keep_ratio`` fraction
    of it. At least one element is kept, so a tiny input does not produce the
    NaN mean of an empty selection.

    Args:
        outputs: predicted probabilities.
        targets: ground truth with the same shape as ``outputs``.
        keep_ratio: fraction of elements (largest loss first) to average;
            the default 0.5 keeps ``numel // 2`` elements.

    Returns:
        0-d tensor: mean of the kept per-element losses.
    """
    flat = F.binary_cross_entropy(outputs, targets, reduction='none').reshape(-1)
    k = min(flat.numel(), max(1, int(flat.numel() * keep_ratio)))
    return flat.topk(k).values.mean()


class SoftDiceLoss(torch.nn.Module):
    """Soft Dice loss ``1 - Dice`` on probabilities, averaged over the batch.

    Per sample, with ``p`` and ``t`` the flattened prediction and target::

        Dice = (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)

    The smoothing term is added once to both numerator and denominator, so
    Dice is 1 for a perfect (or empty-vs-empty) prediction and the loss stays
    in ``[0, 1]`` for inputs in ``[0, 1]``.

    Args:
        weight: accepted for API compatibility; unused.
        size_average: accepted for API compatibility; unused.
        smooth: additive smoothing constant.
    """

    def __init__(self, weight=None, size_average=True, smooth=1):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth

    def forward1(self, logits, targets):
        """Alias of ``forward``, kept for backward compatibility."""
        return self.forward(logits, targets)

    def forward(self, logits, targets):
        # ``logits`` is used as-is (no sigmoid); the training loops in this
        # file already pass sigmoid outputs.
        num = targets.size(0)
        m1 = logits.reshape(num, -1)
        m2 = targets.reshape(num, -1)
        intersection = (m1 * m2).sum(1)
        dice = (2. * intersection + self.smooth) / (m1.sum(1) + m2.sum(1) + self.smooth)
        return 1 - dice.sum() / num


class TverskyLoss(torch.nn.Module):
    """Tversky loss ``1 - TI`` on probabilities, averaged over the batch.

    Reference: https://zhuanlan.zhihu.com/p/103426335

    Per sample, with soft counts ``TP = sum(p * t)``, ``FN = sum(t * (1 - p))``
    and ``FP = sum((1 - t) * p)``::

        TI = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth)

    ``alpha > 0.5`` penalises false negatives more than false positives.

    Args:
        weight: accepted for API compatibility; unused.
        size_average: accepted for API compatibility; unused.
        alpha: weight of false negatives (false positives get ``1 - alpha``).
        smooth: additive smoothing constant.
    """

    def __init__(self, weight=None, size_average=True, alpha=0.7, smooth=1):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.smooth = smooth

    def forward(self, logits, targets):
        # ``logits`` is used as-is (no sigmoid); callers pass probabilities.
        num = targets.size(0)
        m1 = logits.reshape(num, -1)
        m2 = targets.reshape(num, -1)
        true_pos = (m2 * m1).sum(1)
        false_neg = (m2 * (1 - m1)).sum(1)
        false_pos = ((1 - m2) * m1).sum(1)
        index = (true_pos + self.smooth) / (
            true_pos + self.alpha * false_neg + (1 - self.alpha) * false_pos + self.smooth)
        # The Tversky index is a similarity (1 = perfect); minimise its complement.
        return 1 - index.sum() / num
    # DeepLabV3网络模型


def main():
    """Train the DeepLabV3 table-segmentation model (runs until interrupted).

    Each batch is scored with ``pos_hard_mining`` (balanced BCE over the
    positives plus hard negatives) plus 3x ``SoftDiceLoss`` on the sigmoid
    outputs. Every 10 iterations a preview image is written with
    ``save_tensor`` and the loss is printed. From epoch 10 on, the weights are
    saved under ``./ckpt`` every 10 epochs.
    """
    data_fd = '/mnt/data/ocr_data/table_data/contract_bank_merge_860_860_v1'
    ckpt_dir = './ckpt'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dataset = LoadTableImageAndLabels(data_fd)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=24, num_workers=0, shuffle=True, pin_memory=True)
    model = DeepLabV3(1).to(device)
    loss_softdice = SoftDiceLoss()
    # Effectively unbounded: training runs until the process is stopped.
    epochs = 20000000

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
    # The learning rate is divided by 10 after epochs 3, 7 and 13.
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[3, 7, 13], gamma=0.1)
    # torch.save fails when the parent directory is missing; create it up
    # front so the first checkpoint does not abort a long run.
    os.makedirs(ckpt_dir, exist_ok=True)
    for epoch in range(epochs):
        model.train()
        print('epoch: ', epoch)
        for i, (imgs, targets, mask) in enumerate(dataloader):
            imgs = imgs.to(device).float()
            # Move the labels once and reuse them for both loss terms.
            targets = targets.to(device)
            outputs = model(imgs).sigmoid()
            loss1 = pos_hard_mining(outputs, targets, mask.to(device))
            loss2 = loss_softdice(outputs, targets)
            loss = (loss1 + loss2 * 3).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                save_tensor(outputs, i, epoch)
                print('loss is:', loss.item())
        scheduler.step()
        if epoch and epoch % 10 == 0:
            # NOTE(review): the active loss (BCE + 3 * dice) matches log entry
            # 2.1 below, but checkpoints are still tagged v1.9 -- confirm.
            ckpt_name = 'hard_mining_v1.9_' + str(epoch) + '.pth'
            torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))
            # Experiment log (checkpoint version -> change under test):
            # 1.4 negatives = number of pixels around the positives
            # 1.5 set every prediction below 0.5 to 0
            # 1.6 loss replaced entirely by SoftDiceLoss
            # 1.7 loss replaced entirely by TverskyLoss
            # 1.8 loss = pos:neg 1:3 balanced loss * 0.8 + SoftDiceLoss * 0.2
            # 1.9 loss = pos:neg 1:3 balanced loss * 0.5 + SoftDiceLoss * 0.5
            # 2.0 loss = pos:neg 1:3 balanced loss * 3 + SoftDiceLoss
            # 2.1 loss = pos:neg 1:3 balanced loss + SoftDiceLoss * 3


# PANNet network model training
def main_pannet():
    """Train the PAN table-segmentation model (runs until interrupted).

    Each batch is scored with 3x ``pos_hard_mining`` (balanced BCE) +
    2x ``SoftDiceLoss`` + the soft-IoU loss ``IOU``, all on the sigmoid
    outputs. Every 10 iterations a preview image is written with
    ``save_tensor`` and the loss is printed. From epoch 10 on, the weights are
    saved under ``./ckpt`` every 10 epochs.
    """
    data_fd = '/mnt/data/ocr_data/table_data/table_labelme_v1'
    ckpt_dir = './ckpt'
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    dataset = LoadTableImageAndLabels(data_fd)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=8, num_workers=10, shuffle=True, pin_memory=True)
    model = PanModel().to(device)
    loss_softdice = SoftDiceLoss()
    loss_iou = IOU()
    # Effectively unbounded: training runs until the process is stopped.
    epochs = 20000000

    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=5e-4)
    # The learning rate is divided by 10 after epochs 30, 60 and 100.
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60, 100], gamma=0.1)
    # torch.save fails when the parent directory is missing; create it up
    # front so the first checkpoint does not abort a long run.
    os.makedirs(ckpt_dir, exist_ok=True)
    for epoch in range(epochs):
        model.train()
        print('epoch: ', epoch)
        for i, (imgs, targets, mask) in enumerate(dataloader):
            imgs = imgs.to(device).float()
            # Move the labels once and reuse them for all three loss terms.
            targets = targets.to(device)
            outputs = model(imgs).sigmoid()
            loss1 = pos_hard_mining(outputs, targets, mask.to(device))
            loss2 = loss_softdice(outputs, targets)
            loss3 = loss_iou(outputs, targets)
            loss = (loss1 * 3 + loss2 * 2 + loss3).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                save_tensor(outputs, i, epoch)
                print('loss is:', loss.item())
        scheduler.step()
        if epoch and epoch % 10 == 0:
            ckpt_name = 'pan_net_resnet50_gc_860_v1.5.1_' + str(epoch) + '.pth'
            torch.save(model.state_dict(), os.path.join(ckpt_dir, ckpt_name))



# Script entry point: trains the PAN model (main_pannet); main() trains
# DeepLabV3 instead.
if __name__ == '__main__':
    main_pannet()
